package com.adu.music.spider;

import com.adu.music.db.AsyncNamedParameterJdbcTemplate;
import com.adu.music.util.DbUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Skeleton for a producer/consumer spider.
 *
 * <p>{@link #startTask()} drives the lifecycle:
 * <ol>
 *   <li>{@link #initTaskQueue()} runs synchronously on the calling thread;</li>
 *   <li>{@link #produceTask()} runs on a dedicated thread named {@code producer_thread};</li>
 *   <li>{@code consumerNum} threads named {@code consumer_thread_<i>} take tasks from the shared
 *       queue and hand each one to {@link #consumeTask(Object)};</li>
 *   <li>the calling thread blocks until every consumer has exited, then closes
 *       {@link #asyncNamedParameterJdbcTemplate}.</li>
 * </ol>
 *
 * <p>A consumer exits once {@link #produceTask()} has returned and the queue has been drained,
 * or when its thread is interrupted.
 *
 * @author duchuanchuan
 * @date 2016/12/31
 */
public abstract class AbstractSpider {

    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** How long an idle consumer waits for a task before re-checking whether it should exit. */
    private static final long POLL_TIMEOUT_MS = 500L;

    /** Released once the task queue has been initialised by {@link #initTaskQueue()}. */
    private final CountDownLatch initLatch = new CountDownLatch(1);
    /** Number of consumer threads started by {@link #startTask()}. */
    private int consumerNum;
    /** Capacity of the task queue. */
    private int queueCapacity;
    /** Task queue shared by the producer and the consumers. */
    private BlockingQueue<Object> taskQueue;
    /** Set once {@link #produceTask()} has returned (normally or by throwing): no more tasks will be produced. */
    private volatile boolean producerDone;
    /** Asynchronous JDBC template; opened before the consumers start and closed after they have all exited. */
    protected AsyncNamedParameterJdbcTemplate asyncNamedParameterJdbcTemplate = DbUtils.getAsyncNamedPamaterJdbcTemplate();

    /**
     * @param consumerNum   number of consumer threads; must be positive
     * @param queueCapacity capacity of the task queue; must be positive
     * @throws IllegalArgumentException if either argument is not positive
     */
    public AbstractSpider(int consumerNum, int queueCapacity) {
        this.consumerNum = requirePositive(consumerNum, "consumerNum");
        this.queueCapacity = requirePositive(queueCapacity, "queueCapacity");
        this.taskQueue = new LinkedBlockingQueue<>(this.queueCapacity);
    }

    /** Creates a spider with 4 consumer threads and a task queue of capacity 500. */
    public AbstractSpider() {
        this(4, 500);
    }

    /**
     * Sets the task queue capacity. This replaces the queue, so it must be called before
     * {@link #startTask()} and before any task has been enqueued.
     *
     * @param capacity new capacity; must be positive
     * @throws IllegalArgumentException if {@code capacity} is not positive
     * @throws IllegalStateException    if tasks are already queued (they would otherwise be lost)
     */
    public void setQueueCapacity(int capacity) {
        requirePositive(capacity, "capacity");
        if (!taskQueue.isEmpty()) {
            throw new IllegalStateException("Cannot resize task queue while it holds " + taskQueue.size() + " tasks");
        }
        this.queueCapacity = capacity;
        this.taskQueue = new LinkedBlockingQueue<>(capacity);
    }

    /**
     * Sets the number of consumer threads. It takes effect at the next {@link #startTask()}.
     *
     * @param num number of consumer threads; must be positive
     * @throws IllegalArgumentException if {@code num} is not positive
     */
    public void setConsumerNum(int num) {
        this.consumerNum = requirePositive(num, "num");
    }

    /**
     * Initialises the task queue. This runs synchronously on the thread calling
     * {@link #startTask()}, before the producer and the consumers are started.
     *
     * <p>NOTE(review): no consumer is running yet at this point. If an implementation calls
     * {@link #putTaskToQueue(Object)} more times than the queue capacity, the put will block
     * forever.
     */
    protected abstract void initTaskQueue();

    /**
     * Produces tasks, typically via {@link #putTaskToQueue(Object)}. Runs on the producer thread.
     * When this method returns, consumers drain the remaining tasks and then exit.
     */
    protected abstract void produceTask();

    /**
     * Processes one task. Runs on a consumer thread. A {@link RuntimeException} thrown here is
     * logged, and the consumer moves on to the next task.
     *
     * @param task the task to crawl
     */
    protected abstract void consumeTask(Object task);

    /** Starts the producer thread and marks production as finished once {@link #produceTask()} returns. */
    private void produce() {
        logger.info("--- 生产者抓取线程启动 ---");
        new Thread(() -> {
            try {
                produceTask();
            } catch (RuntimeException e) {
                logger.error("Producer thread failed", e);
            } finally {
                // Lets consumers exit once the queue is drained instead of waiting forever.
                producerDone = true;
            }
        }, "producer_thread").start();
    }

    /**
     * Puts a task on the task queue, blocking while the queue is full.
     *
     * @param task the task; must not be null
     * @throws RuntimeException wrapping {@link InterruptedException} if interrupted while waiting;
     *                          the thread's interrupt status is restored first
     */
    protected void putTaskToQueue(Object task) {
        try {
            taskQueue.put(task);
            // Crude back-pressure: once the queue is more than 75% full, pause the producer so the
            // consumers can catch up (put() on its own only blocks when the queue is completely full).
            if (taskQueue.size() > queueCapacity * 0.75) {
                TimeUnit.SECONDS.sleep(3);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Starts the consumer threads and blocks until they have all exited, then closes the JDBC
     * template.
     */
    private void consume() {
        // Block until the task queue has been initialised; this also opens the JDBC template.
        if (!waitForQueueInitDone()) {
            return;
        }
        // Size the latch from the current consumerNum so that setConsumerNum() is honoured.
        final int workers = consumerNum;
        final CountDownLatch finished = new CountDownLatch(workers);
        for (int i = 0; i < workers; i++) {
            final int index = i;
            new Thread(() -> runConsumer(index, finished), "consumer_thread_" + i).start();
        }
        try {
            finished.await();
            asyncNamedParameterJdbcTemplate.close();
        } catch (InterruptedException e) {
            // Consumers may still be using the template, so it is deliberately not closed here.
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while waiting for consumer threads to finish", e);
        }
    }

    /**
     * Body of one consumer thread. It polls the queue until production is finished and the queue
     * is empty, or until the thread is interrupted. It always counts down {@code finished} on exit.
     */
    private void runConsumer(int index, CountDownLatch finished) {
        logger.info("消费者线程{}启动", index);
        try {
            while (!Thread.currentThread().isInterrupted()) {
                // A timed poll (rather than isEmpty() + take()) cannot block forever when another
                // consumer grabs the last task.
                Object task = taskQueue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                if (task != null) {
                    consumeSafely(task);
                } else if (producerDone && taskQueue.isEmpty()) {
                    break;
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            finished.countDown();
            logger.info("消费者线程{}退出", index);
        }
    }

    /** Runs {@link #consumeTask(Object)}, logging instead of propagating a {@link RuntimeException}. */
    private void consumeSafely(Object task) {
        try {
            consumeTask(task);
        } catch (RuntimeException e) {
            logger.error("Failed to consume task {}", task, e);
        }
    }

    /**
     * Runs {@link #initTaskQueue()} and then signals that initialisation is complete.
     */
    private void initQueueDone() {
        initTaskQueue();
        initLatch.countDown();
    }

    /**
     * Waits for the task queue to be initialised, then opens the JDBC template.
     *
     * @return {@code true} if initialisation completed and the template was opened;
     *         {@code false} if the thread was interrupted (its interrupt status is restored)
     */
    private boolean waitForQueueInitDone() {
        try {
            initLatch.await();
            asyncNamedParameterJdbcTemplate.open();
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while waiting for task queue initialisation; consumers not started", e);
            return false;
        }
    }

    /**
     * Runs the spider. It initialises the queue on the calling thread, starts the producer, then
     * starts the consumers and blocks until they have all exited.
     */
    public void startTask() {
        // Initialise the queue.
        initQueueDone();
        // Start the producer.
        produce();
        // Start the consumers and wait for them to finish.
        consume();
    }

    /** Returns {@code value} if it is positive; otherwise throws {@link IllegalArgumentException}. */
    private static int requirePositive(int value, String name) {
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive: " + value);
        }
        return value;
    }
}
